"""Demonstration collection utils."""
from __future__ import annotations
import dataclasses
import datetime
import inspect
import json
import logging
from types import ModuleType

import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Union

from pathlib import Path
import importlib.util

import importlib.metadata

from gymnasium import spaces
from safetensors import safe_open

from bigym.action_modes import (
    JointPositionActionMode,
    ActionMode,
    PelvisDof,
    DEFAULT_DOFS,
)
from bigym.bigym_env import BiGymEnv
from bigym.robots import Robot
from bigym.utils.observation_config import ObservationConfig

from demonstrations.const import TRACKED_PACKAGES


def get_package_version(package_name: str) -> Optional[str]:
    """Look up the installed distribution version of ``package_name``.

    Args:
        package_name: Distribution name as known to ``importlib.metadata``.

    Returns:
        The version string, or None when the distribution is not installed.
    """
    lookup = importlib.metadata
    try:
        installed = lookup.version(package_name)
    except lookup.PackageNotFoundError:
        # Not installed in this interpreter: report "no version".
        installed = None
    return installed


class ObservationMode(Enum):
    """Kind of observations a demonstration was recorded with.

    The member values are the strings serialised into (and parsed back from)
    the demo metadata header.
    """

    # Chosen by ``Metadata.from_env`` when no cameras are configured.
    State = "state"
    # Chosen by ``Metadata.from_env`` when the observation config has cameras.
    Pixel = "pixel"
    # Chosen by ``Metadata.from_env`` when ``is_lightweight`` is requested.
    Lightweight = "lightweight"


@dataclass
class Metadata:
    """BiGym demonstration metadata.

    Describes how a demonstration was recorded (observation mode, environment
    configuration, seed, package versions) and converts that description to
    and from the string-valued metadata header of a safetensors file.
    """

    observation_mode: ObservationMode
    environment_data: EnvData
    seed: int
    # Actuated joint positions of the robot when the metadata was captured
    # (see ``from_env``); empty for demo-store templates.
    reset_positions: list
    # Tracked package name -> installed version (None if not installed).
    package_versions: dict[str, Optional[str]]
    # UTC creation timestamp formatted as YYYY-mm-dd_HH-MM-SS.
    date: str = field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc).strftime(
            "%Y-%m-%d_%H-%M-%S"
        )
    )
    # Random hex identifier; also used to build ``filename``.
    uuid: str = field(default_factory=lambda: uuid.uuid4().hex)

    @classmethod
    def from_safetensors(cls, demo_path: Path):
        """Load metadata from the header of a safetensors demo file.

        Logs a warning for every tracked package whose installed version
        differs from the version recorded in the file.

        Args:
            demo_path: Path to the ``.safetensors`` file.

        Returns:
            Metadata: The decoded metadata.

        Raises:
            KeyError: If the header lacks ``observation_mode`` or
                ``environment_data``.
        """
        with safe_open(demo_path, framework="np", device="cpu") as f:
            metadata = f.metadata() or {}
        # Header values are JSON strings (see ``ready_for_safetensors``).
        metadata = decode_safetensors_metadata(metadata)
        metadata["observation_mode"] = ObservationMode(metadata["observation_mode"])
        metadata["environment_data"] = EnvData.from_safetensors_metadata(
            metadata["environment_data"]
        )
        obj = cls(**metadata)
        obj._check_package_versions()
        return obj

    @classmethod
    def from_env(cls, env: BiGymEnv, is_lightweight: bool = False):
        """Create metadata describing a live BiGym environment.

        Args:
            env: Environment to describe.
            is_lightweight: Force ``ObservationMode.Lightweight``. Otherwise the
                mode is Pixel when the env has cameras configured, else State.

        Returns:
            Metadata: Metadata capturing the env's seed, current actuated joint
            positions, configuration and tracked package versions.
        """
        if is_lightweight:
            observation_mode = ObservationMode.Lightweight
        elif env.observation_config.cameras:
            observation_mode = ObservationMode.Pixel
        else:
            observation_mode = ObservationMode.State
        package_versions = {
            package: get_package_version(package) for package in TRACKED_PACKAGES
        }
        return cls(
            seed=env.seed,
            # Store the actual actuated joint positions, not a zeroed copy.
            reset_positions=env.robot.qpos_actuated.tolist(),
            observation_mode=observation_mode,
            environment_data=get_env_data(env),
            package_versions=package_versions,
        )

    @classmethod
    def for_demo_store(
        cls,
        env_class: type[BiGymEnv],
        action_mode: type[ActionMode],
        floating_base: bool = True,
        floating_dofs: Optional[list[PelvisDof]] = None,
        obs_mode: ObservationMode = ObservationMode.Lightweight,
        observation_config: Optional[ObservationConfig] = None,
        action_mode_absolute: Optional[bool] = True,
    ):
        """Create template metadata used to look up demos in a demo store.

        Args:
            env_class: Environment class; its ``__name__`` is stored.
            action_mode: Action mode class; its ``__name__`` is stored.
            floating_base: Whether the robot base is floating.
            floating_dofs: Floating pelvis DOFs; defaults to ``DEFAULT_DOFS``.
            obs_mode: Observation mode of the requested demos.
            observation_config: Observation config; a fresh
                ``ObservationConfig()`` is created when omitted.
            action_mode_absolute: Absolute (True) vs. delta (False) targets.

        Raises:
            ValueError: If ``obs_mode`` is Pixel but no cameras are configured.
        """
        floating_dofs = DEFAULT_DOFS if floating_dofs is None else floating_dofs
        # Build the default per call: an instance used as a default argument
        # would be created once at import time and shared by every caller.
        if observation_config is None:
            observation_config = ObservationConfig()
        if obs_mode == ObservationMode.Pixel and not observation_config.cameras:
            raise ValueError("Pixel observation mode requires cameras.")
        # NOTE(review): ``get_env_data`` records None for action modes other
        # than JointPositionActionMode, whereas this defaults to True; callers
        # using other action modes may need to pass None explicitly.
        return cls(
            seed=0,
            reset_positions=[],
            observation_mode=obs_mode,
            environment_data=EnvData(
                env_name=env_class.__name__,
                action_mode_name=action_mode.__name__,
                floating_base=floating_base,
                floating_dofs=[dof.value for dof in floating_dofs],
                observation_config=observation_config,
                action_mode_absolute=action_mode_absolute,
            ),
            package_versions={},
        )

    def ready_for_safetensors(self) -> dict:
        """Serialise every field to a JSON string, as safetensors headers require."""
        return {
            "seed": json.dumps(self.seed),
            "reset_positions": json.dumps(self.reset_positions),
            "observation_mode": json.dumps(self.observation_mode.value),
            "environment_data": json.dumps(dataclasses.asdict(self.environment_data)),
            "package_versions": json.dumps(self.package_versions),
            "date": json.dumps(self.date),
            "uuid": json.dumps(self.uuid),
        }

    def get_action_mode_description(self) -> str:
        """Get unified description of the action mode."""
        return self.environment_data.get_action_mode_description()

    def get_camera_description(self) -> str:
        """Get unified description of the cameras."""
        return self.environment_data.get_camera_description()

    @property
    def env_name(self) -> str:
        """Get environment name."""
        return self.environment_data.env_name

    @property
    def filename(self) -> str:
        """File name of the demo: ``<uuid>.safetensors``."""
        return f"{self.uuid}.safetensors"

    @property
    def floating_dof_count(self) -> int:
        """Number of floating DOFs (0 when the base is not floating)."""
        if not self.environment_data.floating_base:
            return 0
        return len(self.environment_data.floating_dofs or DEFAULT_DOFS)

    def _check_package_versions(self):
        """Warn about packages whose installed version differs from the recorded one."""
        for package, saved_version in self.package_versions.items():
            installed_version = get_package_version(package)
            if saved_version != installed_version:
                logging.warning(
                    "Installed version of %s: %s doesn't match version stored in "
                    "demo file: %s. Demo replay could be unstable.",
                    package,
                    installed_version,
                    saved_version,
                )

    def get_env(
        self, control_frequency: int, render_mode: Optional[str] = None
    ) -> BiGymEnv:
        """Instantiate the environment described by this metadata.

        Args:
            control_frequency: Control frequency passed to the env constructor.
            render_mode: Optional render mode passed to the env constructor.

        Raises:
            ValueError: If the stored env name does not resolve to a
                ``BiGymEnv`` subclass in ``bigym.envs``.
        """
        env_name = self.environment_data.env_name
        env_class = find_class_in_module("bigym.envs", env_name)
        if env_class is None:
            raise ValueError(f"Invalid environment name provided: {env_name}")
        if not issubclass(env_class, BiGymEnv):
            raise ValueError(f"Invalid environment class provided: {env_class}")
        return env_class(
            action_mode=self.get_action_mode(),
            observation_config=self.environment_data.observation_config,
            render_mode=render_mode,
            control_frequency=control_frequency,
        )

    def get_action_mode(self) -> ActionMode:
        """Get action mode based on metadata.

        Notes:
            - The action mode is not completely initialized
              until `ActionMode.bind_robot(robot)` is called.

        Raises:
            ValueError: If the stored action mode name does not resolve to a
                class in ``bigym.action_modes``.
        """
        env_data = self.environment_data
        action_mode_class = find_class_in_module(
            "bigym.action_modes", env_data.action_mode_name
        )
        if action_mode_class is None:
            raise ValueError(
                "Invalid action mode name provided: "
                f"{env_data.action_mode_name}"
            )
        # ``EnvData.floating_dofs`` defaults to None (e.g. older metadata);
        # fall back to the defaults, as ``floating_dof_count`` does.
        if env_data.floating_dofs is None:
            floating_dofs = list(DEFAULT_DOFS)
        else:
            floating_dofs = [PelvisDof(dof) for dof in env_data.floating_dofs]
        kwargs = {
            "floating_base": env_data.floating_base,
            "floating_dofs": floating_dofs,
        }
        # Only joint-position control takes an absolute/delta flag.
        if action_mode_class is JointPositionActionMode:
            kwargs["absolute"] = env_data.action_mode_absolute
        return action_mode_class(**kwargs)

    def get_action_space(self, action_scale: float) -> spaces.Box:
        """Get action space based on metadata."""
        robot = Robot(self.get_action_mode())
        return robot.action_mode.action_space(action_scale)


@dataclass
class EnvData:
    """Serialisable description of a BiGym environment configuration.

    Stored inside ``Metadata`` and round-tripped through safetensors headers.
    """

    env_name: str
    action_mode_name: str
    floating_base: bool
    observation_config: ObservationConfig
    # True/False for absolute/delta targets; None when the action mode has no
    # such flag (``get_env_data`` sets None for non-joint-position modes).
    action_mode_absolute: Optional[bool] = None
    # ``PelvisDof`` values of the floating DOFs; None when not recorded.
    floating_dofs: Optional[list[str]] = None

    @classmethod
    def from_safetensors_metadata(cls, metadata: dict):
        """Build an ``EnvData`` from decoded safetensors metadata.

        Args:
            metadata: Decoded ``environment_data`` dict. It is not modified.

        Returns:
            EnvData: The reconstructed environment data.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        data = dict(metadata)
        data["observation_config"] = ObservationConfig.from_safetensors_metadata(
            data["observation_config"]
        )
        return cls(**data)

    def get_action_mode_description(self) -> str:
        """Join action mode name, floating DOFs and absolute/delta with "_"."""
        parts = [self.action_mode_name]
        if self.floating_base:
            parts.append("floating")
            if self.floating_dofs:
                parts.extend(self.floating_dofs)
        if self.action_mode_absolute is not None:
            parts.append("absolute" if self.action_mode_absolute else "delta")
        return "_".join(parts)

    def get_camera_description(self) -> str:
        """Join the cameras' ``to_string()`` forms with "_" ("" if none)."""
        if not self.observation_config.cameras:
            return ""
        return "_".join(camera.to_string() for camera in self.observation_config.cameras)


def get_env_data(env: BiGymEnv) -> EnvData:
    """Capture the serialisable configuration of a BiGym environment.

    Args:
        env: The environment to describe.

    Returns:
        EnvData: Task name, action mode class name and settings, and the
        observation config of ``env``. ``action_mode_absolute`` is None unless
        the action mode is a ``JointPositionActionMode``.
    """
    mode = env.action_mode
    is_joint_position = isinstance(mode, JointPositionActionMode)
    return EnvData(
        env_name=env.task_name,
        action_mode_name=type(mode).__name__,
        action_mode_absolute=mode.absolute if is_joint_position else None,
        floating_base=mode.floating_base,
        # Persist plain enum values so the result is JSON-serialisable.
        floating_dofs=[dof.value for dof in mode.floating_dofs],
        observation_config=env.observation_config,
    )


def decode_safetensors_metadata(metadata: dict) -> dict:
    """JSON-decode the string values of a metadata dict, recursing into dicts.

    String values that are not valid JSON are kept as-is. The dict is updated
    in place, and the same object is returned.

    Args:
        metadata (dict): Dictionary with metadata strings.

    Returns:
        dict: The same dictionary with decoded values.
    """
    for key in metadata:
        value = metadata[key]
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except ValueError:
                pass  # Plain (non-JSON) string: keep it unchanged.
        if isinstance(value, dict):
            value = decode_safetensors_metadata(value)
        # Replacing an existing key's value does not resize the dict, so this
        # is safe during iteration.
        metadata[key] = value
    return metadata


def find_class_in_module(
    module: Union[str, ModuleType], class_name: str
) -> Optional[type]:
    """Find a class by name in a module or any of its loaded submodules.

    Classes visible in ``module`` itself (including ones it imports) are
    checked first. After that, each submodule attribute whose qualified name
    lies under ``module`` is searched recursively (depth-first). Submodules
    that have not been imported are not attributes of the package, so they
    are not discovered.

    Args:
        module: Root module, or its dotted import path (e.g. "bigym.envs").
        class_name: Name of the class.

    Returns:
        Optional[type]: The first matching class, or None if not found.
    """
    if isinstance(module, str):
        module = importlib.import_module(module)
    for name, cls in inspect.getmembers(module, inspect.isclass):
        if name == class_name:
            return cls
    parent_name = f"{module.__name__}."
    # Only descend into true descendants: this skips unrelated imported
    # modules and guarantees the recursion never climbs back to an ancestor.
    submodules = [
        submodule
        for _, submodule in inspect.getmembers(module, inspect.ismodule)
        if submodule.__name__.startswith(parent_name)
    ]
    for submodule in submodules:
        cls = find_class_in_module(submodule, class_name)
        # Compare with None: a class whose metaclass defines __len__/__bool__
        # may be falsy even though it was found.
        if cls is not None:
            return cls
    return None
